// Copyright 2022 Huawei Cloud Computing Technology Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cerrno>
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <sys/socket.h>
#include <linux/in.h>
#include <linux/tcp.h>
#include <unistd.h>
#include <openssl/ssl.h>
#include "CasTcpSocket.h"
#include "../cas_common/CasLog.h"

using namespace std;

/**
 * Creates an unconnected socket wrapper (m_fd == -1, status SOCKET_STATUS_INIT).
 *
 * When built with USE_TLS it also initializes OpenSSL and prepares a TLS client context
 * (minimum protocol TLS 1.2, restricted cipher list, SSL_MODE_AUTO_RETRY) plus one SSL object.
 * Any failure is logged and construction stops early, leaving m_ssl (and possibly m_ctx) null;
 * ConfigSSL() reports SSL_CONFIG_ERROR in that case.
 */
CasTcpSocket::CasTcpSocket()
{
    m_fd = -1;
    m_status = SOCKET_STATUS_INIT;
    m_localIp = 0;
    m_remoteIp = 0;
    m_localPort = 0;
    m_remotePort = 0;
    m_socketOption = 0;
    m_eventNotice = nullptr;
    m_ssl = nullptr;
    m_ctx = nullptr;
    this->msg_no_map.clear();
    this->msg_valid_map.clear();
#if USE_TLS
    // ERR_reason_error_string() returns NULL when the error queue is empty or the code is unknown;
    // never hand NULL to a %s conversion.
    auto sslReason = []() -> const char * {
        const char *reason = ERR_reason_error_string(ERR_get_error());
        return (reason != nullptr) ? reason : "unknown";
    };

    if (SSL_library_init() != 1) {
        ERR("Init ssl library failed %s.", sslReason());
        return;
    }
    OpenSSL_add_all_algorithms();
    SSL_load_error_strings();
    m_ctx = SSL_CTX_new(TLS_client_method());
    if (m_ctx == nullptr) {
        ERR("Failed to new ssl ctx %s.", sslReason());
        return;
    }
    // Fail closed: without this floor older, weaker protocol versions could be negotiated.
    if (SSL_CTX_set_min_proto_version(m_ctx, TLS1_2_VERSION) != 1) {
        ERR("Failed to set min proto version %s.", sslReason());
        return;
    }
    // NOTE(review): SSL_VERIFY_NONE means the server certificate is not verified, so the channel is
    // not protected against man-in-the-middle attacks. Confirm this is intended.
    SSL_CTX_set_verify(m_ctx, SSL_VERIFY_NONE, nullptr);

    // The cipher list only governs TLS 1.2 and below; TLS 1.3 suites are configured separately.
    const char allowedCiphers[] = "ECDHE+ECDSA+AESGCM:!aNULL:!eNULL:!ADH:!SHA:@STRENGTH";
    if (SSL_CTX_set_cipher_list(m_ctx, allowedCiphers) != 1) {
        ERR("Failed set cipher list %s.", sslReason());
        return;
    }
    SSL_CTX_set_mode(m_ctx, SSL_MODE_AUTO_RETRY);
    m_ssl = SSL_new(m_ctx);
    if (m_ssl == nullptr) {
        ERR("Create ssl failed %s.", sslReason());
        return;
    }
#endif
}

/**
 * Marks the socket as exiting and releases everything it owns: the TCP descriptor
 * (shut down in both directions, then closed), the event notice object and, in USE_TLS
 * builds, the OpenSSL context and session.
 */
CasTcpSocket::~CasTcpSocket()
{
    m_status = SOCKET_STATUS_EXIT;

    const bool hasSocket = (m_fd != -1);
    if (hasSocket) {
        (void)shutdown(m_fd, SHUT_RDWR);
        (void)close(m_fd);
    }
    m_fd = -1;

    // Deleting a null pointer is a no-op, so no guard is required here.
    delete m_eventNotice;
    m_eventNotice = nullptr;

#if USE_TLS
    if (m_ctx) {
        SSL_CTX_free(m_ctx);
    }
    m_ctx = nullptr;

    if (m_ssl) {
        SSL_free(m_ssl);
    }
    m_ssl = nullptr;
#endif
}

/**
 * Sends all `size` bytes of `pkt` (SSL_write in USE_TLS builds, send(2) otherwise).
 *
 * Writes are serialized by m_socketLock. Partial writes are continued. Failures whose errno is
 * EAGAIN, EWOULDBLOCK, EINTR or ETIMEDOUT are retried for as long as the status stays
 * SOCKET_STATUS_RUNNING.
 *
 * @param pkt  buffer to send; must not be null.
 * @param size number of bytes to send.
 * @return the total number of bytes sent (== size) on success; SOCKET_SEND_FAIL_RETRY if `pkt`
 *         is null; otherwise SOCKET_SEND_FAIL_DISCONNECT, after setting the status to
 *         SOCKET_STATUS_DISCONNECT.
 */
int CasTcpSocket::Send(void *pkt, size_t size)
{
    if (pkt == nullptr) {
        ERR("Pkt is null.");
        return SOCKET_SEND_FAIL_RETRY;
    }

#if !USE_TLS
    int socketFd = this->m_fd;
#endif
    const size_t total = size;
    const unsigned char *cursor = static_cast<const unsigned char *>(pkt);
    // Initialized so the failure log never reads an indeterminate value when the loop body never runs.
    int ret = 0;

    m_socketLock.lock();
    while (this->m_status == SOCKET_STATUS_RUNNING) {
        // Clear errno so a failure that does not set it is not mistaken for a stale transient error.
        errno = 0;
#if USE_TLS
        ret = SSL_write(m_ssl, cursor, static_cast<int>(size));
#else
        ret = static_cast<int>(::send(socketFd, cursor, size, 0));
#endif
        if ((ret >= 0) && (static_cast<size_t>(ret) == size)) {
            m_socketLock.unlock();
            // Report the whole packet, not just the final chunk of a multi-part write.
            return static_cast<int>(total);
        }
        if (ret > 0) {
            // Partial write: advance past the accepted bytes and send the rest.
            size -= static_cast<size_t>(ret);
            cursor += ret;
            continue;
        }
        // ret <= 0 with data still pending is a failure (SSL_write returns 0 when the connection
        // is closed); only the transient errno values below are retried.
        if (!((EAGAIN == errno) || (EWOULDBLOCK == errno) || (EINTR == errno) || (ETIMEDOUT == errno))) {
            break;
        }
    }
    const int savedErrno = errno;
    m_socketLock.unlock();

    ERR("Send fail, size = %zu ret = %d, errno (%d): %s.", size, ret, savedErrno, strerror(savedErrno));
    if (this->GetStatus() != SOCKET_STATUS_DISCONNECT) {
        this->SetStatus(SOCKET_STATUS_DISCONNECT);
    }

    return SOCKET_SEND_FAIL_DISCONNECT;
}

/**
 * Performs a single read of up to `size` bytes into `pkt` (SSL_read in USE_TLS builds, recv(2) otherwise).
 *
 * @param pkt  destination buffer; must not be null.
 * @param size capacity of `pkt` in bytes.
 * @return the number of bytes read (> 0);
 *         SOCKET_RECV_FAIL_RETRY when `pkt` is null, or when the read failed with errno 0, EAGAIN,
 *         EWOULDBLOCK, EINTR or ETIMEDOUT;
 *         SOCKET_RECV_FAIL_DISCONNECT when the read returned 0 for a non-empty buffer (peer closed) or
 *         failed with any other errno; the status is then set to SOCKET_STATUS_DISCONNECT.
 * If SOCKET_OPTION_BITSET_QUICK_ACK is set in m_socketOption, TCP_QUICKACK is re-enabled after a
 * successful read, because the option is not permanent (see tcp(7)).
 */
int CasTcpSocket::Recv(void *pkt, size_t size)
{
    int quickAck = m_socketOption & SOCKET_OPTION_BITSET_QUICK_ACK;
    int socketFd = this->m_fd;

    if (pkt == nullptr) {
        ERR("Pkt is null.");
        return SOCKET_RECV_FAIL_RETRY;
    }

#if USE_TLS
    const int ret = SSL_read(m_ssl, pkt, static_cast<int>(size));
#else
    const int ret = static_cast<int>(::recv(socketFd, pkt, size, 0));
#endif
    int result = ret;
    if (ret <= 0) {
        quickAck = 0;
        // Classify using the raw return value; it must not be overwritten before this check.
        const int err = errno;
        const bool transientError = (ret < 0) && ((0 == err) || (EAGAIN == err) || (EWOULDBLOCK == err) ||
            (EINTR == err) || (ETIMEDOUT == err));
        // A zero-length request legitimately yields 0 and is not a peer shutdown.
        const bool emptyRequest = (ret == 0) && (size == 0);
        if (transientError || emptyRequest) {
            result = SOCKET_RECV_FAIL_RETRY;
        } else {
            ERR("Recv fail, ret = %d. errno (%d): %s.", ret, err, strerror(err));
            if (this->GetStatus() != SOCKET_STATUS_DISCONNECT) {
                this->SetStatus(SOCKET_STATUS_DISCONNECT);
            }
            result = SOCKET_RECV_FAIL_DISCONNECT;
        }
    }

    if (quickAck) {
        int one = 1;
        setsockopt(socketFd, IPPROTO_TCP, TCP_QUICKACK, &one, sizeof(one));
    }

    return result;
}

int CasTcpSocket::ConfigSSL()
{
#if USE_TLS
    if (m_ctx == nullptr) {
        ERR("Ctx is null.");
        return SSL_CONFIG_ERROR;
    }
    if (m_ssl == nullptr) {
        ERR("Ssl is null.");
        SSL_CTX_free(m_ctx);
        m_ctx = nullptr;
        return SSL_CONFIG_ERROR;
    }
    if (SSL_set_fd(m_ssl, m_fd) != 1) {
        ERR("Failed to set fd %s", ERR_reason_error_string(ERR_get_error()));
        SSL_CTX_free(m_ctx);
        SSL_free(m_ssl);
        close(m_fd);
        m_ctx = nullptr;
        m_ssl = nullptr;
        m_fd = -1;
        return SSL_CONFIG_ERROR;
    }
    int ret = SSL_connect(m_ssl);
    if (ret != 1) {
        ERR("Connect failed %s , ret = %d, errno = %d", ERR_reason_error_string(ERR_get_error()), ret, errno);
        int status = SSL_get_error(m_ssl, ret);
        SSL_CTX_free(m_ctx);
        SSL_free(m_ssl);
        close(m_fd);
        m_ctx = nullptr;
        m_ssl = nullptr;
        m_fd = -1;
        if (status == SSL_ERROR_SYSCALL && (errno == 0 || errno == ECONNRESET)) {
            return SSL_CONFIG_SOCKET_CLOSE;
        } else {
            return SSL_CONFIG_ERROR;
        }
    }
    return SSL_CONFIG_SUCCESS;
#endif
};
